package org.zjvis.datascience.spark.algorithm;

import breeze.linalg.DenseMatrix;
import com.google.common.base.Joiner;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.tsne.impl.BHTSNE;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary;
import org.apache.spark.mllib.stat.Statistics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.collection.JavaConversions;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


/**
 * @description Spark-TSNE 数据降维算子
 * @date 2021-12-23
 */
public class TSNEAlgorithm extends BaseAlgorithm {

    private String[] featureCols;

    private String targetTable;

    private String sourceTable;

    private int k;

    private int maxIter;

    public TSNEAlgorithm(SparkSession sparkSession) {
        super(sparkSession);
    }

    public boolean parseParams(String[] args) {
        super.parseParams(args);

        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));
        options.addOption(
                OptionHelper.getSpecOption("k", "keyComponent", "principal components", true));
        options.addOption(OptionHelper.getSpecOption("m", "maxIter", "max iteration", true));

        boolean flag = this.initCommandLine(args);

        if (!flag) {
            logger.error("initCommandLine fail!!!");
            return false;
        }
        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }
        if (cmd.hasOption("k") || cmd.hasOption("keyComponent")) {
            k = Integer.parseInt(cmd.getOptionValue("keyComponent"));
        }
        if (cmd.hasOption("m") || cmd.hasOption("maxIter")) {
            maxIter = Integer.parseInt(cmd.getOptionValue("maxIter"));
        }

        return true;
    }

    public boolean beginAlgorithm() {
        Dataset<Row> dataset = null;
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, idCol);
            dataset = sparkSession.sql(sql);
        } else {
            dataset = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol)
                    .select(JavaConversions
                            .asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
        }
        VectorAssembler assembler = new VectorAssembler()
                .setHandleInvalid("skip")
                .setInputCols(featureCols)
                .setOutputCol("features");
        Dataset<Row> transDF = assembler.transform(dataset).select(idCol, "features")
                .persist(StorageLevel.MEMORY_AND_DISK());

        logger.debug("transDF.take(5) -> " + transDF.takeAsList(5));
        JavaRDD<Vector> rdd = transDF.select("features").toJavaRDD().flatMap(x -> {
            Vector vector = Vectors.fromML((DenseVector) x.get(0));
            List<Vector> result = new ArrayList<>();
            for (double v : vector.toArray()) {
                Vector tmp = Vectors.dense(new double[]{v});
                result.add(tmp);
            }
            return result.iterator();
        }).persist(StorageLevel.MEMORY_AND_DISK());
        MultivariateStatisticalSummary summary = Statistics.colStats(rdd.rdd());
        logger.debug("rdd.take(5) -> " + rdd.take(5));
        double mean = summary.mean().toArray()[0];

        double stdvar = Math.sqrt(summary.variance().toArray()[0]);
        logger.debug("mean={}", mean);
        logger.debug("stdvar={}", stdvar);

        RDD<Vector> featureRdd = transDF.select("features").toJavaRDD().map(x -> {
            double[] vector = Vectors.fromML((DenseVector) x.get(0)).toArray();
            int length = vector.length;
            for (int i = 0; i < length; ++i) {
                vector[i] = (vector[i] - mean) / stdvar;
            }
            return Vectors.dense(vector);
        }).rdd().persist(StorageLevel.MEMORY_AND_DISK());

        int pcaDim = Math.max((int) (featureCols.length * 0.9), k);

        RowMatrix matrix = new RowMatrix(featureRdd);

        RowMatrix pcaMatrix = matrix.multiply(matrix.computePrincipalComponents(pcaDim));

        DenseMatrix denseMatrix =
//                BHTSNE.tsne(pcaMatrix, k, maxIter, 30, 0.5, (i)-> (int)i % 10 == 0, null, 123456L);
                BHTSNE.tsne(pcaMatrix, k, maxIter, 30, 0.5, null, null, 123456L);

        List<Row> result = new ArrayList<>();
        int rows = denseMatrix.rows();
        for (int i = 0; i < rows; ++i) {
            String[] values = new String[k];
            for (int j = 0; j < k; ++j) {
                values[j] = Double.toString((Double) denseMatrix.valueAt(i, j));
            }
            result.add(RowFactory.create(Joiner.on(",").join(values)));
        }
        StructType tmpSchema = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("features", DataTypes.StringType, true)});

        JavaRDD<Row> resultDf = sparkSession.createDataFrame(result, tmpSchema).toJavaRDD()
                .map(row -> {
                    String[] tmps = row.getString(0).split(",");
                    List<Double> vector = new ArrayList<>();
                    for (String e : tmps) {
                        vector.add(Double.parseDouble(e));
                    }
                    return RowFactory.create(vector.toArray());
                });

        StructType structType = UtilTool.createSchema(new HashMap<>(), new ArrayList<>(), k);
        Dataset<Row> df = sparkSession.createDataFrame(resultDf, structType);

        if (isHive) {
            this.saveHiveTable(df, targetTable);
        } else {
            UtilTool.saveGreenplumTable(df, targetTable);
        }

        Map<String, Object> inputKV = new HashMap<>();
        List<String> outputTables = new ArrayList<>();
        outputTables.add(targetTable);

        String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);

        OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);

        if (isHive) {
            this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
        } else {
            this.saveOutputResultForMysql(sparkSession, outputResult);
        }
        retResult = outputResult;
        transDF.unpersist();
        rdd.unpersist();
        featureRdd.unpersist(false);
        return true;
    }
}
